# This module simulates the 1D Richards equation and acts as the environment for the agent. The volumetric moisture content is used as the state.
#by Bernard Twum Agyeman (agyeman@ualberta.ca)

from Model1D_vectorized import * 
from Parameters_1D import * 
from scipy import integrate
from van_Genuchten import pressureHead, volumetricMoisture

# Soil parameter set handed to the van Genuchten conversions
# (pressureHead / volumetricMoisture) in SimulateModel.
soilPars = pars_mz3()

# Spatial discretization of the soil column. totalNodes sizes the
# per-node state arrays built by the functions below.
(
    totalDepth,
    axialNodes,
    totalNodes,
) = spatialVariables_1D_act()
def ODE(head, timeArray, irrigAmount, refEvap, cropCoeff, rooting_depth, lai_factor):
    """Evaluate the Richards-equation right-hand side for the ODE solver.

    ``timeArray`` is accepted so the signature has a (state, time, ...)
    layout, but it is not forwarded: ``Richards1D_vectorized`` only receives
    the head profile followed by the forcing terms, in the order
    irrigAmount, refEvap, cropCoeff, rooting_depth, lai_factor.

    Returns whatever ``Richards1D_vectorized`` returns; simulate_Richards_equation
    uses it as the derivative function passed to ``solve_ivp``.
    """
    forcing = (irrigAmount, refEvap, cropCoeff, rooting_depth, lai_factor)
    return Richards1D_vectorized(head, *forcing)

def simulate_Richards_equation(initial_condition, irrigation_rates_daily, reference_evap_daily, crop_coeff_daily, time_steps_daily, rooting_depth, lai_factor):
    """Integrate the 1D Richards equation over the simulation horizon.

    The horizon is split into ``totalTimeSteps`` internal steps as returned by
    ``temporal_discretization(time_steps_daily)``. Each step is integrated with
    LSODA over ``[0, samplingTimeInternal]``, starting from the head profile
    reached at the end of the previous step.

    Parameters
    ----------
    initial_condition : array_like
        Pressure head at the start; broadcast into an array of length
        ``totalNodes``.
    irrigation_rates_daily : float
        Irrigation rate for the day. The whole amount is applied during the
        first internal step, multiplied by ``time_steps_daily``.
    reference_evap_daily : float
        Reference evapotranspiration, held constant over the first
        ``totalTimeSteps_daily`` internal steps (SimulateModel passes it
        already divided by 86400*1000).
    crop_coeff_daily : float
        Crop coefficient, held constant over the first
        ``totalTimeSteps_daily`` internal steps.
    time_steps_daily : int
        Forwarded to ``temporal_discretization``.
    rooting_depth, lai_factor
        Forwarded unchanged to ``ODE`` at every step.

    Returns
    -------
    numpy.ndarray
        Pressure head at every node after the last internal step (the initial
        profile if there are no steps).

    Warns
    -----
    RuntimeWarning
        If ``solve_ivp`` reports failure for a step; the last state the solver
        reached is still used, as before.
    """
    import warnings

    _, samplingTimeInternal, _, totalTimeSteps_daily, totalTimeSteps = temporal_discretization(
        time_steps_daily)
    n_first_day = int(totalTimeSteps_daily)

    # Per-step forcing. Only the first day is populated; any later internal
    # steps keep zero irrigation, zero ET and zero Kc.
    irrigation_rates = np.zeros(totalTimeSteps)
    evapotranspiration = np.zeros(totalTimeSteps)
    crop_coeff = np.zeros(totalTimeSteps)
    # The day's irrigation is concentrated into the first internal step. The
    # scaling presumably preserves the daily total when each step lasts
    # 1/time_steps_daily of a day -- confirm against temporal_discretization.
    irrigation_rates[0] = irrigation_rates_daily*time_steps_daily
    evapotranspiration[:n_first_day] = reference_evap_daily
    crop_coeff[:n_first_day] = crop_coeff_daily

    head = np.zeros(totalNodes)
    head[:] = initial_condition  # same broadcasting as the original head_array[0] assignment

    for step in range(totalTimeSteps):
        irrig_k = irrigation_rates[step]
        et_k = evapotranspiration[step]
        kc_k = crop_coeff[step]
        # Pass forcing by keyword so ET lands in ODE's refEvap slot and Kc in
        # its cropCoeff slot (they were previously passed positionally in the
        # opposite order).
        sol = integrate.solve_ivp(
            fun=lambda t, y: ODE(y, t, irrigAmount=irrig_k, refEvap=et_k, cropCoeff=kc_k,
                                 rooting_depth=rooting_depth, lai_factor=lai_factor),
            t_span=[0, samplingTimeInternal],
            y0=head,
            method='LSODA',
        )
        if not sol.success:
            warnings.warn(
                f"solve_ivp failed at internal step {step + 1}/{totalTimeSteps}: {sol.message}",
                RuntimeWarning,
            )
        head = sol.y[:, -1].copy()
    return head


def SimulateModel(state, action, refEvap, cropCoeff, rooting_depth, lai_factor, *, time_steps_daily=4, rng=None):
    """Advance the environment by one decision step and build the next agent state.

    Parameters
    ----------
    state : array_like
        Volumetric moisture content per node; converted to pressure head with
        ``pressureHead(state, soilPars)`` before simulating.
    action : float
        Irrigation for the step; divided by 86400 before use (presumably a
        per-day amount turned into a per-second rate -- confirm units).
    refEvap : float
        Reference evapotranspiration; divided by 86400*1000 before use
        (presumably mm/day -> m/s -- confirm units).
    cropCoeff, rooting_depth, lai_factor : float
        Forcing for this step, forwarded to ``simulate_Richards_equation``.
    time_steps_daily : int, keyword-only
        Internal integration steps per day. Defaults to 4, the previously
        hard-coded value.
    rng : numpy.random.Generator or None, keyword-only
        Source of randomness for the next step's forcing. ``None`` uses the
        global ``np.random`` state exactly as before. Pass e.g.
        ``np.random.default_rng(seed)`` for reproducible episodes.

    Returns
    -------
    numpy.ndarray
        Array of length ``totalNodes + 4``: the next volumetric moisture profile,
        followed by the randomly drawn reference ET, crop coefficient, rooting
        depth and LAI factor for the next step.
    """
    if rng is None:
        rng = np.random  # legacy global RNG: identical draws to the original code

    # Moisture -> head, simulate, head -> moisture.
    state_head = pressureHead(state, soilPars)
    irrig_rate = action/86400
    nextState_head = simulate_Richards_equation(
        state_head, irrig_rate, refEvap/(86400*1000), cropCoeff,
        time_steps_daily, rooting_depth, lai_factor)
    nextState_vol_moist = volumetricMoisture(nextState_head, soilPars)

    # Randomly draw the next step's reference ET, Kc, rooting depth and LAI
    # factor. The draw order is unchanged so a given RNG state yields the same
    # sequence as before. New names avoid overwriting this step's arguments.
    refEvap_gen = rng.uniform(1.04, 9.0)
    cropCoeff_gen = rng.uniform(0.20, 1.25)
    rooting_depth_gen = rng.choice([0.50, 1.00])
    lai_factor_gen = rng.uniform(0, 1.0)

    agent_state = np.zeros(totalNodes+4)
    agent_state[:totalNodes] = nextState_vol_moist
    agent_state[totalNodes:] = (refEvap_gen, cropCoeff_gen, rooting_depth_gen, lai_factor_gen)
    return agent_state